
Non-record: Fused Softcap+CE Megakernel (1.94x vs torch.compile) + N-gram Backoff #915

Open

anthony-maio wants to merge 2 commits into openai:main from anthony-maio:submission/ngram-megakernel

Conversation

@anthony-maio

Non-Record Submission: Makora-Generated Fused CUDA Kernel

Answers OpenAI's explicit request for megakernels.

Fused Softcap + Cross Entropy CUDA Kernel

Generated via Makora automated kernel generation. Fuses 30*tanh(logits/30) + cross_entropy into a single CUDA launch:

  • 1.94x faster than torch.compile
  • 7.51x faster than eager PyTorch
  • Warp-level online softmax reduction
  • bf16 input, float32 accumulation
  • Numerically correct to 5 decimal places (max diff: 0.00001431)

Compiled via torch.utils.cpp_extension.load_inline at startup — zero external dependencies.
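For reference, the softcap + cross-entropy math the kernel fuses, including the online-softmax recurrence each lane runs before the warp-shuffle reduction, can be sketched in pure Python (an illustration of the math only, not the CUDA code; `online_logsumexp` and `softcap_ce` are names chosen here, not from the PR):

```python
import math

def online_logsumexp(values):
    """One-pass log-sum-exp with O(1) state: the same online-softmax
    recurrence each lane uses before the warp-shuffle reduction."""
    mx, se = float("-inf"), 0.0   # running max, and exp-sum scaled by exp(-mx)
    for v in values:
        if v > mx:
            se = se * math.exp(mx - v) + 1.0   # rescale old sum to the new max
            mx = v
        else:
            se += math.exp(v - mx)
    return math.log(se) + mx

def softcap_ce(logits, target, cap=30.0):
    """Cross-entropy over softcapped logits: logsumexp(capped) - capped[target]."""
    capped = [cap * math.tanh(x / cap) for x in logits]
    return online_logsumexp(capped) - capped[target]
```

The kernel's shuffle loop merges two such (max, sum) partial states with the same rescaling identity, while also accumulating the target's capped logit.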

Architecture

Same 11L VRL+LeakyReLU² stack as PR #889 (0.9642 bpb) with the fused kernel integrated into the sliding window eval path. Training uses standard PyTorch (kernel is forward-only).

Validation

  • Kernel compiled and verified on H100
  • fused_softcap_ce:True confirmed in training logs
  • Correctness: max absolute diff vs reference = 0.00001431
  • 8xH100 full run pending (GPUs currently unavailable)

Why This Matters

Custom CUDA kernels at this model scale typically lose to torch.compile (we proved this ourselves with 8 Makora kernels on Day 1 — all added overhead). The softcap+CE fusion is the exception because it eliminates a large intermediate tensor (B*T × V float32 capped logits) that torch.compile cannot optimize away.
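To make the size of that eliminated intermediate concrete, here is the arithmetic with illustrative shapes (the run's actual B*T and vocabulary size are assumptions here, not taken from the PR):

```python
# Illustrative only: assumed eval positions per step and a GPT-2-style vocab.
B_T = 4096                           # positions evaluated per step (assumed)
V = 50257                            # vocabulary size (assumed)
bytes_intermediate = B_T * V * 4     # float32 capped-logits tensor the fusion avoids
print(f"{bytes_intermediate / 2**20:.1f} MiB of intermediate per eval step")
```

At these shapes the fusion skips writing and re-reading roughly three quarters of a gigabyte per eval step, which is the bandwidth torch.compile cannot recover as long as the capped logits are materialized.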

Credits

anthony-maio and others added 2 commits March 26, 2026 20:07
Makora-generated kernel fuses 30*tanh(x/30) + cross_entropy into one
CUDA launch. Warp-level reduction, online softmax, bf16 input.
Compiled via load_inline at startup. Falls back to standard PyTorch
if compilation fails.

Currently loaded but not yet wired into eval path — needs forward_logits
to expose pre-softcap logits. Included for documentation and future
integration.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- forward_logits_raw() returns pre-softcap logits
- Eval uses fused_softcap_ce(raw_logits, targets) when available
- Falls back to standard forward_logits + F.cross_entropy if not
- USE_FUSED_CE=0 to disable
- Logs fused_softcap_ce:True/False at startup

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Copilot AI review requested due to automatic review settings March 27, 2026 01:06

Copilot AI left a comment


Pull request overview

Adds a new 10min/16MB record submission that integrates an n‑gram backoff evaluator and a Makora-generated fused CUDA softcap+cross-entropy path to speed up sliding-window evaluation.

Changes:

  • Introduces an inline-compiled CUDA extension to fuse 30*tanh(logits/30) + cross-entropy into a single kernel call for sliding-window eval.
  • Adds multi-order (2–7) n-gram backoff evaluation with entropy-adaptive mixing.
  • Adds record metadata/artifacts (README, submission.json, training log) and adjusts .gitignore log handling.
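As a rough illustration of what multi-order n-gram backoff evaluation means, here is a generic stupid-backoff-style sketch (illustrative only; the PR's actual orders, smoothing, and entropy-adaptive mixing are not reproduced here):

```python
from collections import defaultdict

def build_counts(tokens, max_order=7):
    """Count n-grams of orders 1..max_order as counts[order][context][token]."""
    counts = {o: defaultdict(lambda: defaultdict(int)) for o in range(1, max_order + 1)}
    for i, tok in enumerate(tokens):
        for order in range(1, max_order + 1):
            if i - (order - 1) < 0:
                break
            counts[order][tuple(tokens[i - (order - 1):i])][tok] += 1
    return counts

def backoff_prob(context, token, counts, max_order=7, alpha=0.4):
    """Score a token under the longest matching context, backing off one
    order at a time with an alpha penalty when the context is unseen."""
    start = min(max_order, len(context) + 1)
    for order in range(start, 0, -1):
        ctx = tuple(context[len(context) - (order - 1):])
        ctx_counts = counts[order].get(ctx)
        if ctx_counts and token in ctx_counts:
            return alpha ** (start - order) * ctx_counts[token] / sum(ctx_counts.values())
    return 0.0
```

An entropy-adaptive mix would then weight such an n-gram probability against the neural model's prediction, leaning on the n-gram side more when its predictive entropy is low.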

Reviewed changes

Copilot reviewed 3 out of 7 changed files in this pull request and generated 6 comments.

Summary per file:

  • records/track_10min_16mb/2026-03-26_NgramBackoff_VRL_LeakyReLU2/train_gpt.py: New record training/eval script with fused softcap+CE CUDA extension + n-gram backoff eval.
  • records/track_10min_16mb/2026-03-26_NgramBackoff_VRL_LeakyReLU2/train_seed1337.log: Training/eval log artifact added for reproducibility.
  • records/track_10min_16mb/2026-03-26_NgramBackoff_VRL_LeakyReLU2/submission.json: Submission metadata for the record entry.
  • records/track_10min_16mb/2026-03-26_NgramBackoff_VRL_LeakyReLU2/README.md: Documents results, method, and reproduction command for the record.
  • .gitignore: Changes log ignore behavior.


except Exception:
    _HAS_FUSED_CE = False

def fused_softcap_ce(logits, targets):
    capped = 30.0 * torch.tanh(logits.float() / 30.0)

Copilot AI Mar 27, 2026


The Python fallback fused_softcap_ce path also hard-codes the 30.0 softcap. This will silently diverge from the configured model softcap if LOGIT_SOFTCAP is changed, even when the CUDA extension fails to build. Consider using args.logit_softcap (or a module-level constant derived from Hyperparameters.logit_softcap) instead of a literal 30.0.

Suggested change
-    capped = 30.0 * torch.tanh(logits.float() / 30.0)
+    softcap = getattr(Hyperparameters, "logit_softcap", 30.0)
+    capped = softcap * torch.tanh(logits.float() / softcap)

Comment on lines +56 to +58
fused_sc_ce<<<(B+3)/4,128>>>((const __nv_bfloat16*)L.data_ptr<at::BFloat16>(),
T.data_ptr<int64_t>(),O.data_ptr<float>(),B,V);
return O;

Copilot AI Mar 27, 2026


The CUDA extension launches fused_sc_ce on the default stream and doesn’t check for launch errors. In PyTorch extensions this can cause incorrect stream semantics (race with surrounding ops on the current stream) and makes kernel failures hard to diagnose. Consider launching on PyTorch’s current CUDA stream (and/or passing an explicit stream in the <<<...>>> launch) and adding a kernel launch error check (e.g., C10_CUDA_KERNEL_LAUNCH_CHECK / AT_CUDA_CHECK).

Comment on lines +23 to +66
_HAS_FUSED_CE = False
try:
    from torch.utils.cpp_extension import load_inline as _load_inline
    _FUSED_CE_SRC = r"""
#include <torch/extension.h>
#include <cuda_runtime.h>
#include <cuda_bf16.h>
#include <math.h>
#define CAP 30.0f
#define INV_CAP 0.03333333333333333f
__device__ __forceinline__ float _sc(float x){return CAP*tanhf(x*INV_CAP);}
extern "C" __global__ void __launch_bounds__(128) fused_sc_ce(
    const __nv_bfloat16* __restrict__ L,const int64_t* __restrict__ T,
    float* __restrict__ O,int B,int V){
  int tid=threadIdx.x,lane=tid&31,wid=tid>>5,row=blockIdx.x*4+wid;
  if(row>=B)return;
  const __nv_bfloat16* rp=L+(size_t)row*V;int tgt=(int)T[row];
  float mx=-1e38f,se=0.f,tc=0.f;
  for(int i=lane;i<V;i+=32){
    float c=_sc(__bfloat162float(rp[i]));
    if(i==tgt)tc=c;
    if(c>mx){se=se*expf(mx-c)+1.f;mx=c;}else{se+=expf(c-mx);}
  }
  for(int o=16;o>0;o>>=1){
    float om=__shfl_xor_sync(0xffffffff,mx,o),os=__shfl_xor_sync(0xffffffff,se,o),
          ot=__shfl_xor_sync(0xffffffff,tc,o);
    float nm=fmaxf(mx,om);se=se*expf(mx-nm)+os*expf(om-nm);mx=nm;tc+=ot;
  }
  if(lane==0)O[row]=logf(se)+mx-tc;
}
torch::Tensor fused_sc_ce_call(torch::Tensor L,torch::Tensor T){
  int B=L.size(0),V=L.size(1);
  auto O=torch::empty({B},L.options().dtype(torch::kFloat32));
  fused_sc_ce<<<(B+3)/4,128>>>((const __nv_bfloat16*)L.data_ptr<at::BFloat16>(),
      T.data_ptr<int64_t>(),O.data_ptr<float>(),B,V);
  return O;
}
"""
    _fused_ce_lib = _load_inline(
        name="fused_sc_ce",
        cpp_sources="torch::Tensor fused_sc_ce_call(torch::Tensor,torch::Tensor);",
        cuda_sources=_FUSED_CE_SRC,
        functions=["fused_sc_ce_call"],
        verbose=False,
        extra_cuda_cflags=["-O3", "--use_fast_math"],
    )
    def fused_softcap_ce(logits, targets):
        if not logits.is_contiguous(): logits = logits.contiguous()
        return _fused_ce_lib.fused_sc_ce_call(logits.bfloat16(), targets)
    _HAS_FUSED_CE = True

Copilot AI Mar 27, 2026


load_inline is executed at import time, so under torchrun each rank will attempt to build/load the CUDA extension concurrently before dist is initialized. This can add large startup overhead and can be flaky due to build-directory contention. Consider deferring compilation until after rank/world_size are known and only compiling on rank 0 (with a barrier), or using a file lock/unique build directory per rank.
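A minimal sketch of the file-lock variant of this suggestion (`build_extension_once` and the lock path are hypothetical names; `build_fn` stands in for the actual `load_inline(...)` call):

```python
import fcntl

def build_extension_once(build_fn, lockfile="/tmp/fused_ce_build.lock"):
    """Serialize JIT compilation across torchrun ranks: the first rank to
    take the lock compiles; the rest block, then hit the cached build dir."""
    with open(lockfile, "w") as f:
        fcntl.flock(f, fcntl.LOCK_EX)   # exclusive: one builder at a time
        try:
            return build_fn()
        finally:
            fcntl.flock(f, fcntl.LOCK_UN)
```

A rank-0 build followed by a `dist.barrier()` achieves the same thing once the process group exists; the file lock works even before `init_process_group`.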

Comment on lines +67 to +68
except Exception:
_HAS_FUSED_CE = False

Copilot AI Mar 27, 2026


The broad except Exception around extension compilation silently disables the fused path without logging the failure reason. This makes it hard to tell whether a run is actually using the fused kernel (and why it isn’t). Consider logging the exception (ideally only on rank 0) or re-raising when USE_FUSED_CE is explicitly requested.

Suggested change
-except Exception:
-    _HAS_FUSED_CE = False
+except Exception as exc:
+    # Extension compilation failed; fall back to the PyTorch implementation.
+    # If fused CE is explicitly requested, surface the error instead of silently disabling it.
+    if os.environ.get("USE_FUSED_CE") == "1":
+        raise
+    _HAS_FUSED_CE = False
+    # Log the failure reason once (on rank 0 if distributed is initialized).
+    log_on_this_rank = True
+    try:
+        if dist.is_available() and dist.is_initialized():
+            log_on_this_rank = dist.get_rank() == 0
+    except Exception:
+        # If rank cannot be queried, default to logging.
+        log_on_this_rank = True
+    if log_on_this_rank:
+        print(
+            "WARNING: fused softcap cross-entropy extension could not be compiled; "
+            f"falling back to PyTorch implementation. Exception: {exc}",
+            file=sys.stderr,
+        )

"training_time_seconds": 600,
"val_bpb": 0.9642,
"val_loss": 1.6279,
"bytes_total": 15953596,

Copilot AI Mar 27, 2026


bytes_total appears inconsistent with the script’s own size accounting (it logs total = bytes_model + bytes_code). Given bytes_code=67048, bytes_total should match the generated artifact size from the run logs; please recompute/update this field so it reflects the actual submission size used for the 16MB cap checks.

Suggested change
-  "bytes_total": 15953596,
+  "bytes_total": 16020644,

Comment on lines +31 to +33
#define CAP 30.0f
#define INV_CAP 0.03333333333333333f
__device__ __forceinline__ float _sc(float x){return CAP*tanhf(x*INV_CAP);}

Copilot AI Mar 27, 2026


The fused softcap+CE kernel hard-codes CAP=30.0 (and INV_CAP accordingly). If LOGIT_SOFTCAP is changed via env/args, the fused path will compute a different loss than the model’s configured softcap. Consider either generating the kernel for args.logit_softcap (or passing cap into the kernel) or disabling use_fused unless args.logit_softcap==30.0.
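One way to act on this, assuming the CUDA source string is assembled in Python before the `load_inline` call (`render_kernel_src` is a hypothetical helper, not code from the PR):

```python
def render_kernel_src(cap: float) -> str:
    """Template the softcap into the CUDA source so the fused path tracks
    the configured value instead of a hard-coded 30.0."""
    return (
        f"#define CAP {cap!r}f\n"
        f"#define INV_CAP {1.0 / cap!r}f\n"
        "__device__ __forceinline__ float _sc(float x){return CAP*tanhf(x*INV_CAP);}\n"
    )
```

Since `load_inline` caches builds by extension name, the name would also need to encode the cap (e.g. `f"fused_sc_ce_{cap}"`) so that changing the softcap triggers a rebuild instead of loading a stale kernel.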
